{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "60a7e08e-93c6-4370-9778-3bb102dce78b",
   "metadata": {},
   "source": [
    "Copyright (c) Meta Platforms, Inc. and affiliates."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3081cd8f-f6f9-4a1a-8c36-8a857b0c3b03",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9f3240f-0354-4802-b8b5-9070930fc957",
   "metadata": {},
   "source": [
    "# CoTracker: It is Better to Track Together\n",
    "This is a demo for <a href=\"https://co-tracker.github.io/\">CoTracker</a>, a model that can track any point in a video."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36ff1fd0-572e-47fb-8221-1e73ac17cfd1",
   "metadata": {},
   "source": [
    "<img src=\"https://www.robots.ox.ac.uk/~nikita/storage/cotracker/bmx-bumps.gif\" alt=\"Logo\" width=\"50%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88c6db31",
   "metadata": {},
   "source": [
    "Don't forget to turn on GPU support if you're running this demo in Colab. \n",
    "\n",
    "**Runtime** -> **Change runtime type** -> **Hardware accelerator** -> **GPU**\n",
    "\n",
    "Let's install dependencies for Colab:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "278876a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !git clone https://github.com/facebookresearch/co-tracker\n",
    "# %cd co-tracker\n",
    "# !pip install -e .\n",
    "# !pip install opencv-python einops timm matplotlib moviepy flow_vis\n",
    "# !mkdir checkpoints\n",
    "# %cd checkpoints\n",
    "# !wget https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ..\n",
    "import os\n",
    "import torch\n",
    "\n",
    "from base64 import b64encode\n",
    "from cotracker.utils.visualizer import Visualizer, read_video_from_path\n",
    "from IPython.display import HTML\n",
    "import torch.nn.functional as F\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7894bd2d-2099-46fa-8286-f0c56298ecd1",
   "metadata": {},
   "source": [
    "Read a video from CO3D:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# video = read_video_from_path('./assets/output.mp4')\n",
    "video = read_video_from_path(\"./assets/breakdance.mp4\")\n",
    "video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()\n",
    "# video = F.interpolate(video[0], scale_factor=0.6, mode='bilinear', align_corners=True)[None]\n",
    "print(video.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_video(video_path):\n",
    "    video_file = open(video_path, \"r+b\").read()\n",
    "    video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n",
    "    return HTML(f\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls><source src=\"{video_url}\"></video>\"\"\")\n",
    " \n",
    "show_video(\"./assets/breakdance.mp4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f89ae18-54d0-4384-8a79-ca9247f5f31a",
   "metadata": {},
   "source": [
    "Import CoTrackerPredictor and create an instance of it. We'll use this object to estimate tracks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d59ac40b-bde8-46d4-bd57-4ead939f22ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cotracker.predictor import CoTrackerPredictor\n",
    "\n",
    "model = CoTrackerPredictor(\n",
    "    checkpoint=os.path.join(\n",
    "        './checkpoints/cotracker_pretrain/cotracker_stride_4_wind_8.pth'\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f2a4485",
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    video = video.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8398155-6dae-4ff0-95f3-dbb52ac70d20",
   "metadata": {},
   "source": [
    "Track points sampled on a regular grid of size 30\\*30 on the first frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17fcaae9-7b3c-474c-977a-cce08a09d580",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "import cv2\n",
    "fps = 1\n",
    "\n",
    "input_mask = './assets/breakdance.png'\n",
    "segm_mask = np.array(Image.open(input_mask))\n",
    "_, T, C, H, W = video.shape\n",
    "vidLen = video.shape[1]\n",
    "idx = torch.range(0, vidLen-1, fps).long()\n",
    "video=video[:, idx]\n",
    "if len(segm_mask.shape)==3:\n",
    "    segm_mask = (segm_mask.mean(axis=-1)>0)\n",
    "segm_mask = cv2.resize(segm_mask, (W, H), interpolation=cv2.INTER_NEAREST)\n",
    "pred_tracks, pred_visibility = model(video, grid_size=50, backward_tracking=False, segm_mask=torch.from_numpy(segm_mask)[None, None])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50a58521-a9ba-4f8b-be02-cfdaf79613a2",
   "metadata": {},
   "source": [
    "Visualize and save the result: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e793ce0-7b77-46ca-a629-155a6a146000",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis = Visualizer(save_dir='./videos', pad_value=0, tracks_leave_trace=10)\n",
    "vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename=input_mask.split('/')[-1].split('.')[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d0733ba-8fe1-4cd4-b963-2085202fba13",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_video(\"./videos/teaser_pred_track.mp4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b033fa31",
   "metadata": {},
   "source": [
    "## SpatialTracker Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6d0eaa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------- import the basic packages ------------\n",
    "%cd ..\n",
    "import os\n",
    "import torch\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "fps = 1\n",
    "from base64 import b64encode\n",
    "from cotracker.utils.visualizer import Visualizer, read_video_from_path\n",
    "from IPython.display import HTML\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "# ---------- read the video ------------\n",
    "video = read_video_from_path(\"./assets/fan.mp4\")\n",
    "video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()\n",
    "video = F.interpolate(video[0], scale_factor=1.0, mode='bilinear', align_corners=True)[None]\n",
    "_, T, C, H, W = video.shape\n",
    "\n",
    "def show_video(video_path):\n",
    "    video_file = open(video_path, \"r+b\").read()\n",
    "    video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n",
    "    return HTML(f\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls><source src=\"{video_url}\"></video>\"\"\")\n",
    " \n",
    "show_video(\"./assets/fan.mp4\")\n",
    "\n",
    "# ---------- run the spatialtracker ------------\n",
    "from spatracker_v1.predictor import CoTrackerPredictor\n",
    "from spatracker_v1.zoeDepth.models.builder import build_model\n",
    "from spatracker_v1.zoeDepth.utils.config import get_config\n",
    "from spatracker_v1.utils.visualizer import Visualizer, read_video_from_path\n",
    "video = video.cuda()\n",
    "vidLen = video.shape[1]\n",
    "idx = torch.range(0, vidLen-1, fps).long()\n",
    "video=video[:, idx]\n",
    "# init the monocular depth perception\n",
    "# conf = get_config(\"zoedepth\", \"infer\", config_version=\"kitti\")\n",
    "conf = get_config(\"zoedepth_nk\", \"infer\")\n",
    "DEVICE = f\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "model_zoe_nk = build_model(conf).to(DEVICE)\n",
    "model_zoe_nk.eval()\n",
    "\n",
    "model = CoTrackerPredictor(\n",
    "    checkpoint=os.path.join(\n",
    "        '/home/xyx/home/codes/co_tracker/checkpoints/spv1/model_cotracker_199375.pth'\n",
    "    )\n",
    ")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    video = video.cuda()\n",
    "\n",
    "input_mask = './assets/fan.png'\n",
    "segm_mask = np.array(Image.open(input_mask))\n",
    "if len(segm_mask.shape)==3:\n",
    "    segm_mask = (segm_mask[..., :3].mean(axis=-1)>0).astype(np.uint8)\n",
    "segm_mask = cv2.resize(segm_mask, (W, H), interpolation=cv2.INTER_NEAREST)\n",
    "pred_tracks, pred_visibility = model(video, grid_size=50, backward_tracking=False, depth_predictor=model_zoe_nk, segm_mask=torch.from_numpy(segm_mask)[None, None])\n",
    "vis = Visualizer(save_dir='./videos', pad_value=0, grayscale=True, tracks_leave_trace=5, fps=15)\n",
    "vis.visualize(video=video, tracks=pred_tracks[..., :2], visibility=pred_visibility, filename='teaser')\n",
    "show_video(\"./videos/teaser_pred_track.mp4\")\n",
    "\n",
    "# ---------- visualize the 4D point cloud ------------\n",
    "import plotly.graph_objects as go\n",
    "from ipywidgets import interact, IntSlider\n",
    "import ipywidgets as widgets\n",
    "xyzt = pred_tracks[0].cpu().numpy()   # T x N x 3\n",
    "intr = np.array([[W, 0.0, W//2],\n",
    "                 [0.0, W, H//2],\n",
    "                 [0.0, 0.0, 1.0]])\n",
    "xyztVis = xyzt.copy()\n",
    "xyztVis[..., 2] = 1.0\n",
    "# xyztVis[..., 0] = 2*(xyztVis[..., 0] / W - 0.5)\n",
    "# xyztVis[..., 1] = 2*(xyztVis[..., 1] / H - 0.5)\n",
    "\n",
    "xyztVis = np.linalg.inv(intr[None, ...]) @ xyztVis.reshape(-1, 3, 1) # (TN) 3 1\n",
    "xyztVis = xyztVis.reshape(T, -1, 3) # T N 3\n",
    "xyztVis[..., 2] *= xyzt[..., 2]\n",
    "scatter = go.Scatter3d(\n",
    "    x=xyztVis[0, :, 0],\n",
    "    y=xyztVis[0, :, 1],\n",
    "    z=xyztVis[0, :, 2],\n",
    "    mode='markers',\n",
    "    marker=dict(\n",
    "        size=3,\n",
    "        color='blue',\n",
    "        opacity=0.8\n",
    "    )\n",
    ")\n",
    "data = [scatter]\n",
    "\n",
    "# layout = go.Layout(\n",
    "#     scene=dict(\n",
    "#         aspectmode='manual',  \n",
    "#         xaxis=dict(title='X'),\n",
    "#         yaxis=dict(title='Y'),\n",
    "#         zaxis=dict(title='Z')\n",
    "#     ),\n",
    "#     uirevision=True\n",
    "# )\n",
    "\n",
    "layout = go.Layout(\n",
    "    scene=dict(\n",
    "        xaxis=dict(range=[-1.5, 1.5], autorange=False),  # 设置 x 轴范围并禁用自动调整\n",
    "        yaxis=dict(range=[-1.5, 1.5], autorange=False),  # 设置 y 轴范围并禁用自动调整\n",
    "        zaxis=dict(range=[-0.5, 30], autorange=False),  # 设置 z 轴范围并禁用自动调整\n",
    "        aspectmode='manual',\n",
    "        aspectratio=dict(x=1, y=1, z=1),\n",
    "    )\n",
    ")\n",
    "\n",
    "# fig = go.Figure(data=data, layout=layout)\n",
    "\n",
    "fig = go.FigureWidget()\n",
    "scatter = fig.add_scatter3d(x=xyztVis[0, :, 0],\n",
    "                             y=xyztVis[0, :, 1], z=xyztVis[0, :, 2],\n",
    "                                 mode='markers',\n",
    "                                 marker=dict(\n",
    "                                 size=1,  \n",
    "                                 color='blue'  \n",
    "                                ))\n",
    "\n",
    "slider = IntSlider(min=0, max=T-1, step=1, value=0)\n",
    "\n",
    "def update(frame):\n",
    "    fig.data[0].x = xyztVis[frame, :, 0]\n",
    "    fig.data[0].y = xyztVis[frame, :, 1]\n",
    "    fig.data[0].z = xyztVis[frame, :, 2]\n",
    "fig.layout = layout\n",
    "widgets.interact(update, frame=slider)\n",
    "display(fig, slider)\n",
    "\n",
    "# def update(frame):\n",
    "#     fig.data[0].x = xyztVis[frame, :, 0]\n",
    "#     fig.data[0].y = xyztVis[frame, :, 1]\n",
    "#     fig.data[0].z = xyztVis[frame, :, 2]\n",
    "#     print(frame)\n",
    "#     # display(fig)\n",
    "\n",
    "# def quit(obj):\n",
    "#     print(\"quit\")\n",
    "#     return\n",
    "    \n",
    "\n",
    "# slider = IntSlider(min=0, max=T-1, step=1, value=0)\n",
    "# btn=widgets.Button (description=\"quit\")\n",
    "# display(btn)\n",
    "# display(fig, slider)\n",
    "\n",
    "# btn.on_click(quit)\n",
    "# interact(update, frame=slider)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "431635fd",
   "metadata": {},
   "source": [
    "## SpatialTracker Final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590b0469",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------- import the basic packages ------------\n",
    "%cd ..\n",
    "import os\n",
    "import torch\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "from base64 import b64encode\n",
    "from cotracker.utils.visualizer import Visualizer, read_video_from_path\n",
    "from IPython.display import HTML\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "fps = 1.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "# ---------- read the video ------------\n",
    "video = read_video_from_path(\"./assets/cheetan.mp4\")\n",
    "video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()\n",
    "video = F.interpolate(video[0], scale_factor=0.8, mode='bilinear', align_corners=True)[None]\n",
    "_, T, C, H, W = video.shape\n",
    "\n",
    "def show_video(video_path):\n",
    "    video_file = open(video_path, \"r+b\").read()\n",
    "    video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n",
    "    return HTML(f\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls><source src=\"{video_url}\"></video>\"\"\")\n",
    " \n",
    "show_video(\"./assets/cheetan.mp4\")\n",
    "\n",
    "# ---------- run the spatialtracker ------------\n",
    "from spatracker1.predictor import CoTrackerPredictor\n",
    "from spatracker1.zoeDepth.models.builder import build_model\n",
    "from spatracker1.zoeDepth.utils.config import get_config\n",
    "from spatracker1.utils.visualizer import Visualizer, read_video_from_path\n",
    "video = video.cuda()\n",
    "# init the monocular depth perception\n",
    "# conf = get_config(\"zoedepth\", \"infer\", config_version=\"kitti\")\n",
    "conf = get_config(\"zoedepth_nk\", \"infer\")\n",
    "DEVICE = f\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "model_zoe_nk = build_model(conf).to(DEVICE)\n",
    "model_zoe_nk.eval()\n",
    "\n",
    "model = CoTrackerPredictor(\n",
    "    checkpoint=os.path.join(\n",
    "        './checkpoints/spv1_noise_new/model_cotracker_061875.pth'\n",
    "        # './checkpoints/SpatialTracker/model_cotracker_199375.pth'\n",
    "    )\n",
    ")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.cuda()\n",
    "    video = video.cuda()\n",
    "    \n",
    "import cv2\n",
    "input_mask = './assets/cheetan.png'\n",
    "segm_mask = np.array(Image.open(input_mask))\n",
    "if len(segm_mask.shape)==3:\n",
    "    segm_mask = segm_mask.mean(axis=-1)\n",
    "segm_mask = cv2.resize(segm_mask, (W, H), interpolation=cv2.INTER_NEAREST)\n",
    "vidLen = video.shape[1]\n",
    "idx = torch.range(0, vidLen-1, fps).long()\n",
    "video=video[:, idx]\n",
    "pred_tracks, pred_visibility = model(video, grid_size=60, backward_tracking=False,\n",
    "                                     depth_predictor=model_zoe_nk, grid_query_frame=0,\n",
    "                                     segm_mask=torch.from_numpy(segm_mask)[None, None], add_new_pts=False)\n",
    "\n",
    "vis = Visualizer(save_dir='./videos', pad_value=0, tracks_leave_trace=10)\n",
    "vis.visualize(video=video, tracks=pred_tracks[..., :2], visibility=pred_visibility, filename='teaser')\n",
    "show_video(\"./videos/teaser_pred_track.mp4\")\n",
    "\n",
    "# ---------- visualize the 4D point cloud ------------\n",
    "import plotly.graph_objects as go\n",
    "from ipywidgets import interact, IntSlider\n",
    "import ipywidgets as widgets\n",
    "xyzt = pred_tracks[0].cpu().numpy()   # T x N x 3\n",
    "intr = np.array([[W, 0.0, W//2],\n",
    "                 [0.0, W, H//2],\n",
    "                 [0.0, 0.0, 1.0]])\n",
    "xyztVis = xyzt.copy()\n",
    "xyztVis[..., 2] = 1.0\n",
    "# xyztVis[..., 0] = 2*(xyztVis[..., 0] / W - 0.5)\n",
    "# xyztVis[..., 1] = 2*(xyztVis[..., 1] / H - 0.5)\n",
    "\n",
    "xyztVis = np.linalg.inv(intr[None, ...]) @ xyztVis.reshape(-1, 3, 1) # (TN) 3 1\n",
    "xyztVis = xyztVis.reshape(T, -1, 3) # T N 3\n",
    "xyztVis[..., 2] *= xyzt[..., 2]\n",
    "scatter = go.Scatter3d(\n",
    "    x=xyztVis[0, :, 0],\n",
    "    y=xyztVis[0, :, 1],\n",
    "    z=xyztVis[0, :, 2],\n",
    "    mode='markers',\n",
    "    marker=dict(\n",
    "        size=3,\n",
    "        color='blue',\n",
    "        opacity=0.8\n",
    "    )\n",
    ")\n",
    "data = [scatter]\n",
    "\n",
    "# layout = go.Layout(\n",
    "#     scene=dict(\n",
    "#         aspectmode='manual',  \n",
    "#         xaxis=dict(title='X'),\n",
    "#         yaxis=dict(title='Y'),\n",
    "#         zaxis=dict(title='Z')\n",
    "#     ),\n",
    "#     uirevision=True\n",
    "# )\n",
    "\n",
    "layout = go.Layout(\n",
    "    scene=dict(\n",
    "        xaxis=dict(range=[-1.5, 1.5], autorange=False),  # 设置 x 轴范围并禁用自动调整\n",
    "        yaxis=dict(range=[-1.5, 1.5], autorange=False),  # 设置 y 轴范围并禁用自动调整\n",
    "        zaxis=dict(range=[-0.5, 12], autorange=False),  # 设置 z 轴范围并禁用自动调整\n",
    "        aspectmode='manual',\n",
    "        aspectratio=dict(x=1, y=1, z=1),\n",
    "    )\n",
    ")\n",
    "\n",
    "# fig = go.Figure(data=data, layout=layout)\n",
    "\n",
    "fig = go.FigureWidget()\n",
    "scatter = fig.add_scatter3d(x=xyztVis[0, :, 0],\n",
    "                             y=xyztVis[0, :, 1], z=xyztVis[0, :, 2],\n",
    "                                 mode='markers',\n",
    "                                 marker=dict(\n",
    "                                 size=1,  \n",
    "                                 color='blue'  \n",
    "                                ))\n",
    "\n",
    "# 1 T N 3\n",
    "pred_tracks2d = pred_tracks[0][:, :, :2]\n",
    "S1, N1, _ = pred_tracks2d.shape\n",
    "video2d = video[0] # T C H W\n",
    "H1, W1 = video[0].shape[-2:] \n",
    "pred_tracks2dNm = pred_tracks2d.clone()\n",
    "pred_tracks2dNm[..., 0] = 2*(pred_tracks2dNm[..., 0] / W1 - 0.5)\n",
    "pred_tracks2dNm[..., 1] = 2*(pred_tracks2dNm[..., 1] / H1 - 0.5)\n",
    "color_interp = torch.nn.functional.grid_sample(video2d, pred_tracks2dNm[:,:,None,:], align_corners=True)\n",
    "# T N 1 3 \n",
    "color_interp = color_interp[:, :, :, 0].permute(0,2,1).cpu().numpy()\n",
    "\n",
    "colored_pts = np.concatenate([xyztVis, color_interp], axis=-1)\n",
    "\n",
    "\n",
    "slider = IntSlider(min=0, max=T-1, step=1, value=0)\n",
    "\n",
    "def update(frame):\n",
    "    fig.data[0].x = xyztVis[frame, :, 0]\n",
    "    fig.data[0].y = xyztVis[frame, :, 1]\n",
    "    fig.data[0].z = xyztVis[frame, :, 2]\n",
    "fig.layout = layout\n",
    "widgets.interact(update, frame=slider)\n",
    "display(fig, slider)\n",
    "\n",
    "np.save('./assets/cheetan.npy', xyztVis)\n",
    "np.save('./assets/cheetan.npy', colored_pts)\n",
    "\n",
    "# def update(frame):\n",
    "#     fig.data[0].x = xyztVis[frame, :, 0]\n",
    "#     fig.data[0].y = xyztVis[frame, :, 1]\n",
    "#     fig.data[0].z = xyztVis[frame, :, 2]\n",
    "#     print(frame)\n",
    "#     # display(fig)\n",
    "\n",
    "# def quit(obj):\n",
    "#     print(\"quit\")\n",
    "#     return\n",
    "    \n",
    "\n",
    "# slider = IntSlider(min=0, max=T-1, step=1, value=0)\n",
    "# btn=widgets.Button (description=\"quit\")\n",
    "# display(btn)\n",
    "# display(fig, slider)\n",
    "\n",
    "# btn.on_click(quit)\n",
    "# interact(update, frame=slider)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c880f3ca-cf42-4f64-9df6-a0e8de6561dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_size = 30\n",
    "grid_query_frame = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd58820-7b23-469e-9b6d-5fa81257981f",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25a85a1d-dce0-4e6b-9f7a-aaf31ade0600",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis = Visualizer(save_dir='./videos', pad_value=100)\n",
    "vis.visualize(\n",
    "    video=video,\n",
    "    tracks=pred_tracks,\n",
    "    visibility=pred_visibility,\n",
    "    filename='grid_query_20',\n",
    "    query_frame=grid_query_frame)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce0fb5b8-d249-4f4e-b59a-51b4f03972c4",
   "metadata": {},
   "source": [
    "Note that tracking starts only from points sampled on a frame in the middle of the video. This is different from the grid in the first example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0b01d51-9222-472b-a714-188c38d83ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_video(\"./videos/grid_query_20_pred_track.mp4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10baad8f-0cb8-4118-9e69-3fb24575715c",
   "metadata": {},
   "source": [
    "### Tracking forward **and backward** from the frame number x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8409e2f7-9e4e-4228-b198-56a64e2260a7",
   "metadata": {},
   "source": [
    "CoTracker is an online algorithm and tracks points only in one direction. However, we can also run it backward from the queried point to track in both directions: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "506233dc-1fb3-4a3c-b9eb-5cbd5df49128",
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_size = 30\n",
    "grid_query_frame = 20"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "495b5fb4-9050-41fe-be98-d757916d0812",
   "metadata": {},
   "source": [
    "Let's activate backward tracking:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "677cf34e-6c6a-49e3-a21b-f8a4f718f916",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame, backward_tracking=True)\n",
    "vis.visualize(\n",
    "    video=video,\n",
    "    tracks=pred_tracks,\n",
    "    visibility=pred_visibility,\n",
    "    filename='grid_query_20_backward',\n",
    "    query_frame=grid_query_frame)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "585a0afa-2cfc-4a07-a6f0-f65924b9ebce",
   "metadata": {},
   "source": [
    "As you can see, we are now tracking points queried in the middle from the first frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d64ab0-7e92-4238-8e7d-178652fc409c",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_video(\"./videos/grid_query_20_backward_pred_track.mp4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb55fb01-6d8e-4e06-9346-8b2e9ef489c2",
   "metadata": {},
   "source": [
    "## Regular grid + Segmentation mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e93a6b0a-b173-46fa-a6d2-1661ae6e6779",
   "metadata": {},
   "source": [
    "Let's now sample points on a grid and filter them with a segmentation mask.\n",
    "This allows us to track points sampled densely on an object because we consume less GPU memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b759548d-1eda-473e-9c90-99e5d3197e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "grid_size = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14ae8a8b-fec7-40d1-b6f2-10e333b75db4",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_mask = './assets/apple_mask.png'\n",
    "segm_mask = np.array(Image.open(input_mask))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e3a1520-64bf-4a0d-b6e9-639430e31940",
   "metadata": {},
   "source": [
    "That's a segmentation mask for the first frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d2efd4e-22df-4833-b9a0-a0763d59ee22",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow((segm_mask[...,None]/255.*video[0,0].permute(1,2,0).cpu().numpy()/255.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b42dce24-7952-4660-8298-4c362d6913cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_tracks, pred_visibility = model(video, grid_size=grid_size, segm_mask=torch.from_numpy(segm_mask)[None, None])\n",
    "vis = Visualizer(\n",
    "    save_dir='./videos',\n",
    "    pad_value=100,\n",
    "    linewidth=2,\n",
    ")\n",
    "vis.visualize(\n",
    "    video=video,\n",
    "    tracks=pred_tracks,\n",
    "    visibility=pred_visibility,\n",
    "    filename='segm_grid')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a386308-0d20-4ba3-bbb9-98ea79823a47",
   "metadata": {},
   "source": [
    "We are now only tracking points on the object (and around):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1810440f-00f4-488a-a174-36be05949e42",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_video(\"./videos/segm_grid_pred_track.mp4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a63e89e4-8890-4e1b-91ec-d5dfa3f93309",
   "metadata": {},
   "source": [
    "## Dense Tracks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ae764d8-db7c-41c2-a712-1876e7b4372d",
   "metadata": {},
   "source": [
    "### Tracking forward **and backward** from the frame number x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dde3237-ecad-4c9b-b100-28b1f1b3cbe6",
   "metadata": {},
   "source": [
    "CoTracker also has a mode to track **every pixel** in a video in a **dense** manner but it is much slower than in previous examples. Let's downsample the video in order to make it faster: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "379557d9-80ea-4316-91df-4da215193b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "video.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6db5cc7-351d-4d9e-9b9d-3a40f05b077a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "video_interp = F.interpolate(video[0], [100,180], mode=\"bilinear\")[None]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ba32cb3-97dc-46f5-b2bd-b93a094dc819",
   "metadata": {},
   "source": [
    "The video now has a much lower resolution:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0918f246-5556-43b8-9f6d-88013d5a487e",
   "metadata": {},
   "outputs": [],
   "source": [
    "video_interp.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc7d3a2c-5e87-4c8d-ad10-1f9c6d2ffbed",
   "metadata": {},
   "source": [
    "Again, let's track points in both directions. This will only take a couple of minutes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b852606-5229-4abd-b166-496d35da1009",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_tracks, pred_visibility = model(video_interp, grid_query_frame=20, backward_tracking=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4143ab14-810e-4e65-93f1-5775957cf4da",
   "metadata": {},
   "source": [
    "Visualization with an optical flow color encoding:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5394b0ba-1fc7-4843-91d5-6113a6e86bdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "vis = Visualizer(\n",
    "    save_dir='./videos',\n",
    "    pad_value=20,\n",
    "    linewidth=1,\n",
    "    mode='optical_flow'\n",
    ")\n",
    "vis.visualize(\n",
    "    video=video_interp,\n",
    "    tracks=pred_tracks,\n",
    "    visibility=pred_visibility,\n",
    "    query_frame=grid_query_frame,\n",
    "    filename='dense')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9113c2ac-4d25-4ef2-8951-71a1c1be74dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_video(\"./videos/dense_pred_track.mp4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e9bce0-382b-4d18-9316-7f92093ada1d",
   "metadata": {},
   "source": [
    "That's all, now you can use CoTracker in your projects!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "DynNeRFusion",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
